{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n完整版和简化版的区别是有两个value模型\\n还有动态调整alpha\\n其他的和简化版一样'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Notes (translated): compared with the simplified version, this full version uses\n",
    "# two value networks and adjusts alpha (the entropy weight) automatically;\n",
    "# everything else is the same as the simplified version.\n",
    "\"\"\"\n",
    "完整版和简化版的区别是有两个value模型\n",
    "还有动态调整alpha\n",
    "其他的和简化版一样\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.8014121 ,  0.5981126 , -0.01938946], dtype=float32)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Environment wrapper: adapts gym's 5-tuple step API to (state, reward, done, info)\n",
    "# and ends every episode after at most max_steps steps.\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    def __init__(self, max_steps=200):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.max_steps = max_steps\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        # Forward optional reset arguments (e.g. seed=...) so episodes can be reproduced;\n",
    "        # gym's info dict is dropped and only the observation is returned.\n",
    "        state, _ = self.env.reset(**kwargs)\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        # terminated and truncated are merged into a single done flag\n",
    "        done = terminated or truncated\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= self.max_steps:\n",
    "            done = True\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlF0lEQVR4nO3df3DU9YH/8ddns8kCCbshgWxISYQ7qRoRzoLCtlO1JSXatOcP2loHOc4yemJwRBznpKc4Os6E0X5rtVXsjXdi6ykt12KVil4aMGiNgJEoP6NWNBHYBMHsJoFsfuz7+4fHnougCXySfSd5PmZ2pvl83vvOez8lebr7+ezGMcYYAQBgIU+qFwAAwMkQKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtVIWqUceeUQTJ07UiBEjNHPmTG3ZsiVVSwEAWColkfrd736npUuX6u6779abb76padOmqbS0VM3NzalYDgDAUk4qPmB25syZuuCCC/SrX/1KkhSPx1VYWKibb75Zd9xxx0AvBwBgKe9Af8POzk7V1tZq2bJliW0ej0clJSWqqak54X1isZhisVji63g8rsOHDys3N1eO4/T7mgEA7jLGqLW1VQUFBfJ4Tv6i3oBH6uOPP1ZPT4+CwWDS9mAwqD179pzwPhUVFbrnnnsGYnkAgAHU2NioCRMmnHT/gEfqVCxbtkxLly5NfB2JRFRUVKTGxkb5/f4UrgwAcCqi0agKCws1evToLxw34JEaO3as0tLS1NTUlLS9qalJ+fn5J7yPz+eTz+f73Ha/30+kAGAQ+7JTNgN+dV9GRoamT5+uqqqqxLZ4PK6qqiqFQqGBXg4AwGIpeblv6dKlWrBggWbMmKELL7xQv/jFL9Te3q7rrrsuFcsBAFgqJZG6+uqrdfDgQS1fvlzhcFj/8A//oBdffPFzF1MAAIa3lLxP6nRFo1EFAgFFIhHOSQHAINTb3+N8dh8AwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAa/U5Ups2bdL3v/99FRQUyHEcPfvss0n7jTFavny5xo8fr5EjR6qkpETvvvtu0pjDhw9r3rx58vv9ys7O1sKFC9XW1nZaDwQAMPT0OVLt7e2aNm2aHnnkkRPuv//++/Xwww/rscce0+bNm5WZmanS0lJ1dHQkxsybN087d+5UZWWl1q1bp02bNumGG2449UcBABiazGmQZNauXZv4Oh6Pm/z8fPPAAw8ktrW0tBifz2eeeeYZY4wxu3btMpLM1q1bE2PWr19vHMcx+/bt69X3jUQiRpKJRCKns3wAQIr09ve4q+ek9u7dq3A4rJKSksS2QCCgmTNnqqamRpJUU1Oj7OxszZgxIzGmpKREHo9HmzdvPuG8sVhM0Wg06QYAGPpcjVQ4HJYkBYPBpO3BYDCxLxwOKy8vL2m/1+tVTk5OYszxKioqFAgEErfCwkI3lw0AsNSguLpv2bJlikQiiVtjY2OqlwQAGACuRio/P1+S1NTUlLS9qakpsS8/P1/Nzc1J+7u7u3X48OHEmOP5fD75/f6kGwBg6HM1UpMmTVJ+fr6qqqoS26LRqDZv3qxQKCRJ
CoVCamlpUW1tbWLMhg0bFI/HNXPmTDeXAwAY5Lx9vUNbW5vee++9xNd79+5VXV2dcnJyVFRUpCVLlui+++7T5MmTNWnSJN11110qKCjQFVdcIUk655xzdOmll+r666/XY489pq6uLi1evFg//vGPVVBQ4NoDAwAMAX29bHDjxo1G0uduCxYsMMZ8ehn6XXfdZYLBoPH5fGb27Nmmvr4+aY5Dhw6Za665xmRlZRm/32+uu+4609ra6vqliwAAO/X297hjjDEpbOQpiUajCgQCikQinJ8CgEGot7/HB8XVfQCA4YlIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCs5U31AoDTZXp61B2NqjsaVTwWkyQ5GRny+v3y+v3yePlnDgxW/PRiUOs6fFiHN21S9K23dPSDD9TV0iIZI28goJFnnKHRU6cq56KL5MvLS/VSAZwCIoVByRijI+++q31PPaW2XbtkOjuT9nd/8olaP/lEbTt3KrJ1q74yf76yiovleHiFGxhM+InFoGOMUduOHfrgl79Ua13d5wKVNLarS+27d+vDX/1KkTfekInHB3ClAE4XkcKg03X4sD564gl1fPhhr+8T279f+558UrFwuB9XBsBtRAqDionHte83v9GR997r8307GhvV+O//rnh3dz+sDEB/IFIYVNrfeUftu3ef8v3bdu3SwfXrCRUwSBApDCptO3ac1kt28Y4OHfzzn9XR0ODiqgD0FyKFYSe2f7/+VlGh7tbWVC8FwJcgUhiWOpubdeD3v1dPR0eqlwLgCxApDE/G6NBf/qKW119P9UoAfAEihUHF4/PJceljjnra29X42GNq3b5dxhhX5gTgLiKFQSXnkks06u//3rX5eo4c+fRlv2jUtTkBuIdIYVDxjh6tsaWlUlqaa3O2vv22wn/4A8+mAAsRKQw67Xv2SG5+vJExal63Tk1/+INMT4978wI4bUQKg07+D3/o2nmpY0x3tw5t3Kgj77/PMyrAIkQKg07GuHEqmDfP9Xk7Ghv10X/8h8SzKcAaRAqDjuPxKHj55SpYsEBOerqrc7ft2qX3778/8ccTAaQWkcKg5KSladycORoxYYLrc7ft3q1Pamo4PwVYgEhh0PKOHq3CG25w/dlUdySi/b/9rWJNTZyfAlKMSGFQyyou1pl33aW00aNdnbfz4EG989OfqqetzdV5AfQNkcKg5jiOss45R9mzZrk+d3c0qqbnnuP8FJBCRAqDnsfn0/irr9aIoiJX5zXd3Tr4/PNqef11/uw8kCJECkOCLy9PX73vPvny812dt+fIEe198MFT+kvAAE4fkcKQ4fX7lXflle5PHI/rwO9/r+72dvfnBvCFiBSGDMfjUe5FFyl39mzJcVydO1JbqwOrVyve1eXqvAC+GJHCkJKWmami8nIFLrzQ3Yl7etT8pz/p4Pr1vH8KGEBECkOOx+tVwY9/3C9zH3zxRcWam/tlbgCfR6QwJI2cOFGF//IvcjIyXJ039tFH+uAXv1AP56eAAUGkMCQ5aWkad+mlGv+jH7n+ientu3er4de/Vk9Hh6vzAvg8IoUhy0lL07jLLlNaVpbrc0fr6tT69tu8fwroZ0QKQ1paVla/fGxSd0uLPnjoIcX273d1XgDJiBSGNMdxNOrMM3XG4sXyBgKuzt3T2qq/VVSoKxJxdV4A/4dIYchzHEf+889X5uTJrs8da2rSocpK3j8F9BMihWEhbcQITbrtNo2cNMnVeU1np/Y/84w+eeUVzk8B/YBIYdjwjBqlv7v9do0oLHR1XtPVpQ9/+Ut17Nvn6rwAiBSGEcdx5Bs/XmO++U3X5zY9PTrwzDPqOXLE9bmB4axPkaqoqNAFF1yg0aNHKy8vT1dccYXq
6+uTxnR0dKi8vFy5ubnKysrS3Llz1dTUlDSmoaFBZWVlGjVqlPLy8nT77beru7v79B8N8CWctDTl/+AHGnvppa5/vt8nf/2r9v3mN+o5etTVeYHhrE+Rqq6uVnl5uV5//XVVVlaqq6tLc+bMUftn3n1/66236vnnn9eaNWtUXV2t/fv366qrrkrs7+npUVlZmTo7O/Xaa6/pySef1KpVq7R8+XL3HhXwBTxeryb85Cfyn3++uxMbo4MvvKDIli382XnAJY45jZ+mgwcPKi8vT9XV1brooosUiUQ0btw4Pf300/rBD34gSdqzZ4/OOecc1dTUaNasWVq/fr2+973vaf/+/QoGg5Kkxx57TP/6r/+qgwcPKqMXH2MTjUYVCAQUiUTk9/tPdfkYxowxan3rLb17zz2Syx8Y6/vKV/TV++5TRm6uq/MCQ0lvf4+f1jmpyP++PyQnJ0eSVFtbq66uLpWUlCTGnH322SoqKlJNTY0kqaamRuedd14iUJJUWlqqaDSqnTt3nvD7xGIxRaPRpBtwOhzH0eipU3XGokXyjBzp6tyxffv0/ooV6jx40NV5geHolCMVj8e1ZMkSfeMb39CUKVMkSeFwWBkZGcrOzk4aGwwGFQ6HE2M+G6hj+4/tO5GKigoFAoHErdDlq7MwPDkej3K+/W3llZW5fn6qvb5eTc89p3gs5uq8wHBzypEqLy/Xjh07tHr1ajfXc0LLli1TJBJJ3BobG/v9e2J48Hi9yrv8cqVlZro+96GqKrXt2cP5KeA0nFKkFi9erHXr1mnjxo2aMGFCYnt+fr46OzvV0tKSNL6pqUn5+fmJMcdf7Xfs62Njjufz+eT3+5NugFu8fr8m33uvMvLyXJ23p61N71dU6Mh777k6LzCc9ClSxhgtXrxYa9eu1YYNGzTpuHfvT58+Xenp6aqqqkpsq6+vV0NDg0KhkCQpFApp+/btav7MH46rrKyU3+9XcXHx6TwW4JQ4jqNRkyap4NprXf/E9J4jR7TvySf5fD/gFPUpUuXl5Xrqqaf09NNPa/To0QqHwwqHwzr6v+8LCQQCWrhwoZYuXaqNGzeqtrZW1113nUKhkGbNmiVJmjNnjoqLizV//ny99dZbeumll3TnnXeqvLxcPp/P/UcI9IKTlqYxX/+6Rk6c6PrcbXv26FBVleK8FxDosz5dgu6c5OTyE088oX/+53+W9OmbeW+77TY988wzisViKi0t1aOPPpr0Ut6HH36oRYsW6eWXX1ZmZqYWLFigFStWyNvLP07HJejoLz0dHXr3rrvUftyb1N1QdNNNGltaetKfI2A46e3v8dN6n1SqECn0F2OM2t95Rx8+/LA6XL5Ax+v366yf/UwjTnLuFRhOBuR9UsBQ4ziOMs88U9mzZklpaa7O3d3aqv1PPaXu1lZX5wWGMiIFHMdJS1PBtddq7Jw57k5sjD7ZtEkfrVqlno4Od+cGhigiBZyA4zgquOYajZ42zfW5D1VWqp33TwG9QqSAk/AGAhr33e8qbdQo1+du+PWv1fmZt2EAODEiBZyE4zgaEwppwsKF8rj89ojYvn3a+/Ofq+PAAVfnBYYaIgV8iZxLLnH//JSk9t279cmrr8rw/ingpIgU8CU86ekKXnGFMsaNc33uA6tXq233btfnBYYKIgX0Qsa4cZp0223KcPk9TqarS+8/8ICib7/t6rzAUEGkgF7KPOssjb/6atfPT3W3tOjjF19Ud2srV/wBxyFSQC85aWnKveQSZZ59tutzf/Laa/r4pZekeNz1uYHBjEgBfeCkpWnS0qXKOvdcdyeOx7Xvt7/VwRdfdHdeYJAjUkAfebOzVTBvnnzjx7s7sTE6+Oc/K3bgAC/7Af+LSAF95DiOss49VzkXXeT65/t1fPSR9j31lHr4fD9AEpECTonjOBp/zTUaW1Li+tyfvPKKGh9/XKanx/W5gcGGSAGnynGU/8Mf9svn+33y6qtqq6/nZT8Me0QKOEWO4yhj3DiN++53
5Q0EXJ3bdHer4dFH1fHRR67OCww2RAo4Dcc+36/g2mvdPz/V0KAPHnpIsYMHXZ0XGEyIFOCC3G99S2NLS12f98g77yiyeTMv+2HYIlKAC5z0dAW/9z2NnDjR9bk/WrVKrXV1rs8LDAZECnCB4zgaMWGCCq+/Xum5ua7ObTo79cHDDyu6bRvPqDDsECnARVnnnquCefMkx3F13q5Dh3R40ybFjx4lVBhWiBTgIsfjUc7FFyswY4brcx+qqlLz+vWuzwvYjEgBLvOkp6vw+uvd/3w/SQeeflrNzz3HsykMG0QK6AcZwaAKrr1W6Tk5rs5rurrU/Pzzaq2rI1QYFogU0A8cx9Hoc8/VuMsu6/N997W3a11jo555/339Zf9+tXd1Je3vbG5W+I9/VE9bm1vLBazlTfUCgKEseOWVijU16dBf/vKlY40x2tvWpru3bdMHbW3q6OmRPz1dU8aM0c8uuEDpnv/7b8rWt97Skb/9TaOnTZPj8kUagE14JgX0Iyc9XeN/9CNlTZnypWPfb2vT9X/9q3ZHIjra0yMjKdLVpb82N+uWzZt1qKMjaXzDY4/106oBexApoB85jiNffr6Cl1+utKysLxz7i507FTnupb1jtnz8sSr370/aZk4yFhhKiBQwALJnztRX5s+XPPzIAX3BTwwwQHK+9S2N/c53Ur0MYFAhUsAASRsxQnnf+55G/t3fnXB/WWGh0k9yEcTErCxNdflydmAwIFLAABp5xhkqvP76E/79qdKCAt19/vkakZaW+MFMcxzl+nz6fxdcoOLs7KTxwblz+3/BQIpxCTowwLKKi1Uwb54aHn00abvjOCotKNCEUaO07qOPdKijQxOzsnT1pEnK9fmSxvrGj9eYUIjLzzHkESlggDmOozHf/KZa335b7e++q86DB6V4PLFvypgxmjJmzEnv7x0zRgXz58vr9w/UkoGUIVJACngzM1V0003qaW/Xgd/9Todfflmmu/tL75eWlaXxV1+t7AsvlOPyXwIGbESkgBTxZmUpLTNTExYuVHpurg5v3KjO5uYTjnXS0jSisFDBK69UziWX8DIfhg0iBaSQ4zjyZmZq/A9+IP/UqfrktdfUtnOnYuGw4rGY0rKyNGLCBAVmzFBgxgyNLCoiUBhWiBRgAY/Pp6wpU5T51a+q5+hRme5umXhcTlqaPOnp8owaJY+XH1cMP/yrByzhOI4cn0+e467kA4Yz3icFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgrT5FauXKlZo6dar8fr/8fr9CoZDWr1+f2N/R0aHy8nLl5uYqKytLc+fOVVNTU9IcDQ0NKisr06hRo5SXl6fbb79d3d3d7jwaAMCQ0qdITZgwQStWrFBtba3eeOMNffvb39bll1+unTt3SpJuvfVWPf/881qzZo2qq6u1f/9+XXXVVYn79/T0qKysTJ2dnXrttdf05JNPatWqVVq+fLm7jwoAMDSY0zRmzBjz+OOPm5aWFpOenm7WrFmT2Ld7924jydTU1BhjjHnhhReMx+Mx4XA4MWblypXG7/ebWCzW6+8ZiUSMJBOJRE53+QCAFOjt7/FTPifV09Oj1atXq729XaFQSLW1terq6lJJSUlizNlnn62ioiLV1NRIkmpqanTeeecpGAwmxpSWlioajSaejZ1ILBZTNBpNugEAhr4+R2r79u3KysqSz+fTjTfeqLVr16q4uFjhcFgZGRnKzs5OGh8MBhUOhyVJ4XA4KVDH9h/bdzIVFRUKBAKJW2FhYV+XDQAYhPocqbPOOkt1dXXavHmzFi1apAULFmjXrl39sbaEZcuWKRKJJG6NjY39+v0AAHbw9vUOGRkZOvPMMyVJ06dP19atW/XQQw/p6quvVmdnp1pa
WpKeTTU1NSk/P1+SlJ+fry1btiTNd+zqv2NjTsTn88nn8/V1qQCAQe603ycVj8cVi8U0ffp0paenq6qqKrGvvr5eDQ0NCoVCkqRQKKTt27erubk5MaayslJ+v1/FxcWnuxQAwBDTp2dSy5Yt02WXXaaioiK1trbq6aef1ssvv6yXXnpJgUBACxcu1NKlS5WTkyO/36+bb75ZoVBIs2bNkiTNmTNHxcXFmj9/vu6//36Fw2HdeeedKi8v55kSAOBz+hSp5uZm/dM//ZMOHDigQCCgqVOn6qWXXtJ3vvMdSdKDDz4oj8ejuXPnKhaLqbS0VI8++mji/mlpaVq3bp0WLVqkUCikzMxMLViwQPfee6+7jwoAMCQ4xhiT6kX0VTQaVSAQUCQSkd/vT/VyAAB91Nvf43x2HwDAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrnVakVqxYIcdxtGTJksS2jo4OlZeXKzc3V1lZWZo7d66ampqS7tfQ0KCysjKNGjVKeXl5uv3229Xd3X06SwEADEGnHKmtW7fq17/+taZOnZq0/dZbb9Xzzz+vNWvWqLq6Wvv379dVV12V2N/T06OysjJ1dnbqtdde05NPPqlVq1Zp+fLlp/4oAABDkzkFra2tZvLkyaaystJcfPHF5pZbbjHGGNPS0mLS09PNmjVrEmN3795tJJmamhpjjDEvvPCC8Xg8JhwOJ8asXLnS+P1+E4vFevX9I5GIkWQikcipLB8AkGK9/T1+Ss+kysvLVVZWppKSkqTttbW16urqStp+9tlnq6ioSDU1NZKkmpoanXfeeQoGg4kxpaWlikaj2rlz5wm/XywWUzQaTboBAIY+b1/vsHr1ar355pvaunXr5/aFw2FlZGQoOzs7aXswGFQ4HE6M+Wygju0/tu9EKioqdM899/R1qQCAQa5Pz6QaGxt1yy236L/+6780YsSI/lrT5yxbtkyRSCRxa2xsHLDvDQBInT5Fqra2Vs3Nzfra174mr9crr9er6upqPfzww/J6vQoGg+rs7FRLS0vS/ZqampSfny9Jys/P/9zVfse+PjbmeD6fT36/P+kGABj6+hSp2bNna/v27aqrq0vcZsyYoXnz5iX+d3p6uqqqqhL3qa+vV0NDg0KhkCQpFApp+/btam5uToyprKyU3+9XcXGxSw8LADAU9Omc1OjRozVlypSkbZmZmcrNzU1sX7hwoZYuXaqcnBz5/X7dfPPNCoVCmjVrliRpzpw5Ki4u1vz583X//fcrHA7rzjvvVHl5uXw+n0sPCwAwFPT5wokv8+CDD8rj8Wju3LmKxWIqLS3Vo48+mtiflpamdevWadGiRQqFQsrMzNSCBQt07733ur0UAMAg5xhjTKoX0VfRaFSBQECRSITzUwAwCPX29zif3QcAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gB
AKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsJY31Qs4FcYYSVI0Gk3xSgAAp+LY7+9jv89PZlBG6tChQ5KkwsLCFK8EAHA6WltbFQgETrp/UEYqJydHktTQ0PCFD264i0ajKiwsVGNjo/x+f6qXYy2OU+9wnHqH49Q7xhi1traqoKDgC8cNykh5PJ+eSgsEAvwj6AW/389x6gWOU+9wnHqH4/TlevMkgwsnAADWIlIAAGsNykj5fD7dfffd8vl8qV6K1ThOvcNx6h2OU+9wnNzlmC+7/g8AgBQZlM+kAADDA5ECAFiLSAEArEWkAADWGpSReuSRRzRx4kSNGDFCM2fO1JYtW1K9pAG1adMmff/731dBQYEcx9Gzzz6btN8Yo+XLl2v8+PEaOXKkSkpK9O677yaNOXz4sObNmye/36/s7GwtXLhQbW1tA/go+ldFRYUuuOACjR49Wnl5ebriiitUX1+fNKajo0Pl5eXKzc1VVlaW5s6dq6ampqQxDQ0NKisr06hRo5SXl6fbb79d3d3dA/lQ+tXKlSs1derUxBtPQ6GQ1q9fn9jPMTqxFStWyHEcLVmyJLGNY9VPzCCzevVqk5GRYf7zP//T7Ny501x//fUmOzvbNDU1pXppA+aFF14w//Zv/2b++Mc/Gklm7dq1SftXrFhhAoGAefbZZ81bb71l/vEf/9FMmjTJHD16NDHm0ksvNdOmTTOvv/66eeWVV8yZZ55prrnmmgF+JP2ntLTUPPHEE2bHjh2mrq7OfPe73zVFRUWmra0tMebGG280hYWFpqqqyrzxxhtm1qxZ5utf/3pif3d3t5kyZYopKSkx27ZtMy+88IIZO3asWbZsWSoeUr947rnnzJ///GfzzjvvmPr6evPTn/7UpKenmx07dhhjOEYnsmXLFjNx4kQzdepUc8sttyS2c6z6x6CL1IUXXmjKy8sTX/f09JiCggJTUVGRwlWlzvGRisfjJj8/3zzwwAOJbS0tLcbn85lnnnnGGGPMrl27jCSzdevWxJj169cbx3HMvn37BmztA6m5udlIMtXV1caYT49Jenq6WbNmTWLM7t27jSRTU1NjjPn0PwY8Ho8Jh8OJMStXrjR+v9/EYrGBfQADaMyYMebxxx/nGJ1Aa2urmTx5sqmsrDQXX3xxIlIcq/4zqF7u6+zsVG1trUpKShLbPB6PSkpKVFNTk8KV2WPv3r0Kh8NJxygQCGjmzJmJY1RTU6Ps7GzNmDEjMaakpEQej0ebN28e8DUPhEgkIun/Ppy4trZWXV1dScfp7LPPVlFRUdJxOu+88xQMBhNjSktLFY1GtXPnzgFc/cDo6enR6tWr1d7erlAoxDE6gfLycpWVlSUdE4l/T/1pUH3A7Mcff6yenp6k/5MlKRgMas+ePSlalV3C4bAknfAYHdsXDoeVl5eXtN/r9SonJycxZiiJx+NasmSJvvGNb2jKlCmSPj0GGRkZys7OThp7/HE60XE8tm+o2L59u0KhkDo6OpSVlaW1a9equLhYdXV1HKPPWL16td58801t3br1c/v499R/BlWkgFNRXl6uHTt26NVXX031Uqx01llnqa6uTpFIRP/93/+tBQsWqLq6OtXLskpjY6NuueUWVVZWasSIEalezrAyqF7uGzt2rNLS0j53xUxTU5Py8/NTtCq7HDsOX3SM8vPz1dzcnLS/u7tbhw8fHnLHcfHixVq3bp02btyoCRMmJLbn5+ers7NTLS0tSeOPP04nOo7H9g0VGRkZOvPMMzV9+nRVVFRo2rRpeuihhzhGn1FbW6vm5mZ97Wtfk9frldfrVXV1tR5++GF5vV4Fg0GOVT8ZVJHK
yMjQ9OnTVVVVldgWj8dVVVWlUCiUwpXZY9KkScrPz086RtFoVJs3b04co1AopJaWFtXW1ibGbNiwQfF4XDNnzhzwNfcHY4wWL16stWvXasOGDZo0aVLS/unTpys9PT3pONXX16uhoSHpOG3fvj0p6JWVlfL7/SouLh6YB5IC8XhcsViMY/QZs2fP1vbt21VXV5e4zZgxQ/PmzUv8b45VP0n1lRt9tXr1auPz+cyqVavMrl27zA033GCys7OTrpgZ6lpbW822bdvMtm3bjCTz85//3Gzbts18+OGHxphPL0HPzs42f/rTn8zbb79tLr/88hNegn7++eebzZs3m1dffdVMnjx5SF2CvmjRIhMIBMzLL79sDhw4kLgdOXIkMebGG280RUVFZsOGDeaNN94woVDIhEKhxP5jlwzPmTPH1NXVmRdffNGMGzduSF0yfMcdd5jq6mqzd+9e8/bbb5s77rjDOI5j/ud//scYwzH6Ip+9us8YjlV/GXSRMsaYX/7yl6aoqMhkZGSYCy+80Lz++uupXtKA2rhxo5H0uduCBQuMMZ9ehn7XXXeZYDBofD6fmT17tqmvr0+a49ChQ+aaa64xWVlZxu/3m+uuu860tram4NH0jxMdH0nmiSeeSIw5evSouemmm8yYMWPMqFGjzJVXXmkOHDiQNM8HH3xgLrvsMjNy5EgzduxYc9ttt5murq4BfjT95yc/+Yk544wzTEZGhhk3bpyZPXt2IlDGcIy+yPGR4lj1D/5UBwDAWoPqnBQAYHghUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFr/H6fRrLnXFCUXAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "def show():\n",
    "    \"\"\"Render the environment's current frame and display it inline.\"\"\"\n",
    "    frame = env.render()\n",
    "    plt.imshow(frame)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[1.5660],\n",
       "         [1.6829]], grad_fn=<MulBackward0>),\n",
       " tensor([[0.9125],\n",
       "         [0.8339]], grad_fn=<NegBackward0>))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class ModelAction(torch.nn.Module):\n",
    "    \"\"\"Policy network: maps a [b, 3] state batch to a sampled action in (-2, 2) and a\n",
    "    one-sample entropy estimate (negative log-probability of the squashed action), both [b, 1].\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc_state = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 128),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "        self.fc_mu = torch.nn.Linear(128, 1)\n",
    "        # Softplus keeps the standard deviation positive\n",
    "        self.fc_std = torch.nn.Sequential(\n",
    "            torch.nn.Linear(128, 1),\n",
    "            torch.nn.Softplus(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        #[b, 3] -> [b, 128]\n",
    "        state = self.fc_state(state)\n",
    "\n",
    "        #[b, 128] -> [b, 1]\n",
    "        mu = self.fc_mu(state)\n",
    "\n",
    "        #[b, 128] -> [b, 1]\n",
    "        std = self.fc_std(state)\n",
    "\n",
    "        # Define b normal distributions from mu and std\n",
    "        dist = torch.distributions.Normal(mu, std)\n",
    "\n",
    "        # rsample = reparameterized sample (mu + std * eps with eps ~ N(0, 1)),\n",
    "        # so gradients flow back into mu and std\n",
    "        sample = dist.rsample()\n",
    "\n",
    "        # Squash the sample into (-1, 1)\n",
    "        action = torch.tanh(sample)\n",
    "\n",
    "        # Log-density of the unsquashed sample\n",
    "        log_prob = dist.log_prob(sample)\n",
    "\n",
    "        # Change of variables for the tanh squashing: subtract log(1 - tanh(sample)^2).\n",
    "        # `action` already equals tanh(sample), so it must not be passed through tanh again.\n",
    "        # (The constant log(2) Jacobian of the *2 scaling below is not included.)\n",
    "        entropy = log_prob - (1 - action**2 + 1e-7).log()\n",
    "        entropy = -entropy\n",
    "\n",
    "        # Scale the squashed action by 2, giving the range (-2, 2)\n",
    "        return action * 2, entropy\n",
    "\n",
    "\n",
    "model_action = ModelAction()\n",
    "\n",
    "model_action(torch.randn(2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1950],\n",
       "        [-0.0635]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class ModelValue(torch.nn.Module):\n",
    "    \"\"\"Value network: scores a (state [b, 3], action [b, 1]) pair with a [b, 1] value.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        hidden = 128\n",
    "        self.sequential = torch.nn.Sequential(\n",
    "            torch.nn.Linear(4, hidden),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(hidden, hidden),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(hidden, 1),\n",
    "        )\n",
    "\n",
    "    def forward(self, state, action):\n",
    "        # Join along the feature axis: [b, 3] + [b, 1] -> [b, 4], then [b, 4] -> [b, 1]\n",
    "        state_action = torch.cat((state, action), dim=1)\n",
    "        return self.sequential(state_action)\n",
    "\n",
    "\n",
    "# Two online value networks, created in this order, followed by their target copies\n",
    "model_value1, model_value2 = ModelValue(), ModelValue()\n",
    "model_value_next1, model_value_next2 = ModelValue(), ModelValue()\n",
    "\n",
    "# Target networks start as exact copies of their online counterparts\n",
    "model_value_next1.load_state_dict(model_value1.state_dict())\n",
    "model_value_next2.load_state_dict(model_value2.state_dict())\n",
    "\n",
    "model_value1(torch.randn(2, 3), torch.randn(2, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.9855514168739319"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def get_action(state):\n",
    "    \"\"\"Sample one action (a Python float) from the policy for a single 3-value state.\"\"\"\n",
    "    state = torch.FloatTensor(state).reshape(1, 3)\n",
    "    # Acting only needs the sampled value, so skip building an autograd graph\n",
    "    with torch.no_grad():\n",
    "        action, _ = model_action(state)\n",
    "    return action.item()\n",
    "\n",
    "\n",
    "get_action([1, 2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(200,\n",
       " (array([ 0.9439776 , -0.33000955, -0.1769515 ], dtype=float32),\n",
       "  -0.6050740480422974,\n",
       "  -0.11660419226099256,\n",
       "  array([ 0.935164  , -0.35421515, -0.51521975], dtype=float32),\n",
       "  False))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Replay buffer of (state, action, reward, next_state, over) tuples\n",
    "datas = []\n",
    "\n",
    "\n",
    "def update_data(max_size=100000):\n",
    "    \"\"\"Play one full episode with the current policy, append every transition to\n",
    "    `datas`, then drop the oldest entries so that at most max_size remain.\"\"\"\n",
    "    # Start a new episode\n",
    "    state = env.reset()\n",
    "\n",
    "    # Play until the episode ends\n",
    "    over = False\n",
    "    while not over:\n",
    "        # Choose an action for the current state\n",
    "        action = get_action(state)\n",
    "\n",
    "        # Apply it and observe the result\n",
    "        next_state, reward, over, _ = env.step([action])\n",
    "\n",
    "        # Store the transition.\n",
    "        # NOTE(review): `over` is also True when the step limit is hit (truncation),\n",
    "        # and get_target treats that like a real terminal state.\n",
    "        datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "        # Advance to the next state\n",
    "        state = next_state\n",
    "\n",
    "    # Trim the oldest samples with one slice deletion (a pop(0) loop costs O(n) per pop)\n",
    "    overflow = len(datas) - max_size\n",
    "    if overflow > 0:\n",
    "        del datas[:overflow]\n",
    "\n",
    "\n",
    "update_data()\n",
    "\n",
    "len(datas), datas[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2141/1710091499.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.1646, -0.9864, -5.7690],\n",
       "         [-0.5338,  0.8456,  8.0000],\n",
       "         [ 0.3142,  0.9493,  2.8880],\n",
       "         [-0.9838, -0.1794,  7.8493],\n",
       "         [-0.7075,  0.7068,  8.0000]]),\n",
       " tensor([[ 0.2345],\n",
       "         [ 1.9940],\n",
       "         [-0.5016],\n",
       "         [-1.2680],\n",
       "         [ 1.9993]]),\n",
       " tensor([[ -6.3424],\n",
       "         [-10.9575],\n",
       "         [ -2.3997],\n",
       "         [-14.9316],\n",
       "         [-11.9580]]),\n",
       " tensor([[-0.4698, -0.8828, -6.4736],\n",
       "         [-0.8210,  0.5710,  8.0000],\n",
       "         [ 0.1429,  0.9897,  3.5247],\n",
       "         [-0.8491, -0.5283,  7.5246],\n",
       "         [-0.9268,  0.3755,  8.0000]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def get_sample(batch_size=64):\n",
    "    \"\"\"Draw batch_size random transitions from `datas` and stack them into tensors:\n",
    "    state [b, 3], action [b, 1], reward [b, 1], next_state [b, 3], over [b, 1] (int64).\"\"\"\n",
    "    samples = random.sample(datas, batch_size)\n",
    "\n",
    "    # Stack the per-step numpy states into one float32 array first; building a tensor\n",
    "    # directly from a list of ndarrays is slow and triggers a UserWarning\n",
    "    #[b, 3]\n",
    "    state = torch.from_numpy(np.array([i[0] for i in samples], dtype=np.float32)).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    action = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 3]\n",
    "    next_state = torch.from_numpy(np.array([i[3] for i in samples], dtype=np.float32)).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state[:5], action[:5], reward[:5], next_state[:5], over[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1683.9410110535862"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    \"\"\"Play one episode with the current policy and return the total reward (higher is\n",
    "    better). When play is truthy, the frame is redrawn on roughly 20% of steps.\"\"\"\n",
    "    state = env.reset()\n",
    "    total_reward = 0\n",
    "\n",
    "    done = False\n",
    "    while not done:\n",
    "        state, reward, done, _ = env.step([get_action(state)])\n",
    "        total_reward += reward\n",
    "\n",
    "        # Frame skipping: only redraw on a random ~20% of steps\n",
    "        if play and random.random() < 0.2:\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return total_reward\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def soft_update(model, model_next, tau=0.005):\n",
    "    \"\"\"Blend model's parameters into model_next in place:\n",
    "    next <- (1 - tau) * next + tau * model. The default tau reproduces the original 0.005.\"\"\"\n",
    "    for param, param_next in zip(model.parameters(), model_next.parameters()):\n",
    "        # Move the target weights a small step towards the online weights\n",
    "        value = param_next.data * (1 - tau) + param.data * tau\n",
    "        param_next.data.copy_(value)\n",
    "\n",
    "\n",
    "soft_update(torch.nn.Linear(4, 64), torch.nn.Linear(4, 64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-4.6052, requires_grad=True)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "# Learnable entropy weight stored in log space: alpha.exp() is the weight actually used\n",
    "# (initially 0.01), and optimizing the log keeps that weight positive\n",
    "alpha = torch.tensor(math.log(0.01), requires_grad=True)\n",
    "\n",
    "alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_target(reward, next_state, over, gamma=0.99):\n",
    "    \"\"\"TD target for the value networks:\n",
    "    reward + gamma * (1 - over) * (min(Q1', Q2')(next_state, a') + alpha.exp() * entropy(a')),\n",
    "    where a' is sampled from the current policy and Q1', Q2' are the target value networks.\n",
    "    Shapes: reward/over [b, 1], next_state [b, 3]; returns [b, 1].\n",
    "    Computed under torch.no_grad(): the result is a fixed regression target (train()\n",
    "    detaches it), so building an autograd graph here is wasted work.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Sample next actions and their entropy estimates from the policy\n",
    "        #[b, 3] -> [b, 1],[b, 1]\n",
    "        action, entropy = model_action(next_state)\n",
    "\n",
    "        # Evaluate next_state with both target value networks\n",
    "        #[b, 3],[b, 1] -> [b, 1]\n",
    "        target1 = model_value_next1(next_state, action)\n",
    "        target2 = model_value_next2(next_state, action)\n",
    "\n",
    "        # Take the smaller estimate for stability\n",
    "        #[b, 1]\n",
    "        target = torch.min(target1, target2)\n",
    "\n",
    "        # alpha is stored as a log, so exp() recovers the weight of the entropy bonus\n",
    "        #[b, 1] + [b, 1] -> [b, 1]\n",
    "        target += alpha.exp() * entropy\n",
    "\n",
    "        # Discount, drop the bootstrap term where the episode ended, add the reward\n",
    "        #[b, 1]\n",
    "        target *= gamma\n",
    "        target *= (1 - over)\n",
    "        target += reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.4448, grad_fn=<MeanBackward0>),\n",
       " tensor([[0.5181],\n",
       "         [1.1856],\n",
       "         [0.6098],\n",
       "         [2.2213],\n",
       "         [1.0059],\n",
       "         [1.0943],\n",
       "         [0.5175],\n",
       "         [0.5336],\n",
       "         [0.3465],\n",
       "         [1.9152],\n",
       "         [0.2851],\n",
       "         [1.9617],\n",
       "         [0.3080],\n",
       "         [0.5080],\n",
       "         [0.5073],\n",
       "         [0.6083],\n",
       "         [1.8864],\n",
       "         [1.8668],\n",
       "         [0.3378],\n",
       "         [0.4438],\n",
       "         [1.4230],\n",
       "         [2.0435],\n",
       "         [0.7549],\n",
       "         [1.4198],\n",
       "         [2.3602],\n",
       "         [0.4413],\n",
       "         [1.0770],\n",
       "         [0.9458],\n",
       "         [0.7440],\n",
       "         [1.4281],\n",
       "         [1.1989],\n",
       "         [0.2889],\n",
       "         [1.2037],\n",
       "         [0.6786],\n",
       "         [0.6629],\n",
       "         [0.4009],\n",
       "         [1.8516],\n",
       "         [1.6241],\n",
       "         [0.3535],\n",
       "         [0.5151],\n",
       "         [1.5448],\n",
       "         [0.4732],\n",
       "         [1.2274],\n",
       "         [0.7507],\n",
       "         [0.4770],\n",
       "         [0.6567],\n",
       "         [0.9315],\n",
       "         [0.5582],\n",
       "         [0.8405],\n",
       "         [0.5776],\n",
       "         [0.5886],\n",
       "         [0.4185],\n",
       "         [0.6450],\n",
       "         [0.5078],\n",
       "         [0.1775],\n",
       "         [0.3839],\n",
       "         [2.3009],\n",
       "         [0.7732],\n",
       "         [1.5574],\n",
       "         [0.4892],\n",
       "         [1.5161],\n",
       "         [1.7142],\n",
       "         [0.1019],\n",
       "         [1.6037]], grad_fn=<NegBackward0>))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_loss_action(state):\n",
    "    \"\"\"Policy loss for a [b, 3] state batch: mean of -alpha.exp() * entropy - min(Q1, Q2)\n",
    "    over freshly sampled actions. Returns (scalar loss, entropy [b, 1]).\"\"\"\n",
    "    # Sample actions (keeping gradients) and their entropy estimates\n",
    "    action, entropy = model_action(state)\n",
    "\n",
    "    # Pessimistic value: the smaller of the two online value estimates, [b, 1]\n",
    "    value = torch.min(model_value1(state, action), model_value2(state, action))\n",
    "\n",
    "    # Maximizing weighted entropy + value is the same as minimizing its negation\n",
    "    loss_action = -alpha.exp() * entropy - value\n",
    "\n",
    "    return loss_action.mean(), entropy\n",
    "\n",
    "\n",
    "get_loss_action(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 0.00935130100697279 -1384.9634034870005\n",
      "10 2400 0.005746820010244846 -1273.817515890381\n",
      "20 4400 0.003627561964094639 -353.5943787095424\n",
      "40 8400 0.001914315391331911 -353.12730080624937\n",
      "50 10400 0.00204734830185771 -227.96972393226093\n",
      "60 12400 0.0015838450053706765 -251.13651326686494\n",
      "70 14400 0.0012645457172766328 -217.23893143587378\n",
      "80 16400 0.0014361173380166292 -166.3167507733615\n",
      "90 18400 0.0011741687776520848 -206.47951713582424\n"
     ]
    }
   ],
   "source": [
    "# Main training loop: 100 epochs; each epoch plays one episode into the replay buffer\n",
    "# and then runs 200 gradient steps. Every 10 epochs it prints the epoch, buffer size,\n",
    "# current entropy weight alpha.exp() and the mean reward of 10 test episodes.\n",
    "def train():\n",
    "    optimizer_action = torch.optim.Adam(model_action.parameters(), lr=3e-4)\n",
    "    optimizer_value1 = torch.optim.Adam(model_value1.parameters(), lr=3e-3)\n",
    "    optimizer_value2 = torch.optim.Adam(model_value2.parameters(), lr=3e-3)\n",
    "\n",
    "    # alpha (stored as log alpha) is also learned, so it gets its own optimizer\n",
    "    optimizer_alpha = torch.optim.Adam([alpha], lr=3e-4)\n",
    "\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    # Train for a fixed number of epochs\n",
    "    for epoch in range(100):\n",
    "        # Add one episode of fresh transitions to the buffer\n",
    "        update_data()\n",
    "\n",
    "        # After each data update, take a fixed number of gradient steps\n",
    "        for i in range(200):\n",
    "            # Sample a batch of transitions\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            # Shift and scale the reward to make training easier\n",
    "            # (presumably maps Pendulum's negative rewards to roughly [-1, 1] - TODO confirm)\n",
    "            reward = (reward + 8) / 8\n",
    "\n",
    "            # Compute the TD target; it already includes the entropy bonus\n",
    "            #[b, 1]\n",
    "            target = get_target(reward, next_state, over)\n",
    "            target = target.detach()\n",
    "\n",
    "            # Evaluate both online value networks on the stored (state, action)\n",
    "            value1 = model_value1(state, action)\n",
    "            value2 = model_value2(state, action)\n",
    "\n",
    "            # Both value networks regress towards the same target\n",
    "            loss_value1 = loss_fn(value1, target)\n",
    "            loss_value2 = loss_fn(value2, target)\n",
    "\n",
    "            # Update the value networks\n",
    "            optimizer_value1.zero_grad()\n",
    "            loss_value1.backward()\n",
    "            optimizer_value1.step()\n",
    "\n",
    "            optimizer_value2.zero_grad()\n",
    "            loss_value2.backward()\n",
    "            optimizer_value2.step()\n",
    "\n",
    "            # Policy loss uses the value networks as critics. Its backward() also writes\n",
    "            # gradients into the value networks and alpha; those are cleared by the\n",
    "            # zero_grad() calls before they are next updated.\n",
    "            loss_action, entropy = get_loss_action(state)\n",
    "            optimizer_action.zero_grad()\n",
    "            loss_action.backward()\n",
    "            optimizer_action.step()\n",
    "\n",
    "            # alpha loss: alpha.exp() * (entropy + 1). Gradient descent lowers alpha while\n",
    "            # entropy > -1 and raises it while entropy < -1, so -1 acts as the target entropy\n",
    "            #[b, 1] -> scalar\n",
    "            loss_alpha = (entropy + 1).detach() * alpha.exp()\n",
    "            loss_alpha = loss_alpha.mean()\n",
    "\n",
    "            # Update alpha\n",
    "            optimizer_alpha.zero_grad()\n",
    "            loss_alpha.backward()\n",
    "            optimizer_alpha.step()\n",
    "\n",
    "            # Soft-update the target (next) value networks towards the online ones\n",
    "            soft_update(model_value1, model_value_next1)\n",
    "            soft_update(model_value2, model_value_next2)\n",
    "\n",
    "        if epoch % 10 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(10)]) / 10\n",
    "            print(epoch, len(datas), alpha.exp().item(), test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjYElEQVR4nO3dfXBc5WHv8d+utFpZL7uyZGsX1RJ2g4uja+wS29hbZkIaKxaJkkJwp4TxEIf6wuDKro07TFFrzIR2RpTcWxJaMJ3JBPNHwB2nMQQXQ1WZCBKEbYRFZBnU0JhKwV4Jv2hXkq3Vvjz3D/BeFgyRrJXOo+X7mdkZfM7ZR885yPv17p496zLGGAEAYCG30xMAAOCTECkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLUci9Qjjzyi+fPnq7CwUCtXrtShQ4ecmgoAwFKOROpf//VftW3bNt133316/fXXtXTpUtXX12tgYMCJ6QAALOVy4gKzK1eu1IoVK/TP//zPkqRUKqXq6mpt3rxZ99xzz3RPBwBgqfzp/oFjY2Pq6OhQU1NTepnb7VZdXZ3a29svep9YLKZYLJb+cyqV0pkzZ1RRUSGXyzXlcwYAZJcxRkNDQ6qqqpLb/ckv6k17pE6dOqVkMqlAIJCxPBAI6K233rrofZqbm/Xd7353OqYHAJhGfX19mjdv3ieun/ZIXYqmpiZt27Yt/edIJKKamhr19fXJ5/M5ODMAwKWIRqOqrq5WaWnpp2437ZGaM2eO8vLy1N/fn7G8v79fwWDwovfxer3yer0fW+7z+YgUAMxgv+stm2k/u6+goEDLli1Ta2trelkqlVJra6tCodB0TwcAYDFHXu7btm2b1q9fr+XLl+uaa67R97//fY2MjOi2225zYjoAAEs5Eqmbb75Z7733nnbs2KFwOKw//MM/1PPPP/+xkykAAJ9tjnxOarKi0aj8fr8ikQjvSQHADDTex3Gu3QcAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWhOO1EsvvaRvfOMbqqqqksvl0tNPP52x3hijHTt26LLLLtOsWbNUV1enX//61xnbnDlzRuvWrZPP51NZWZk2bNig4eHhSe0IACD3TDhSIyMjWrp0qR555JGLrn/wwQf18MMP67HHHtPBgwdVXFys+vp6jY6OprdZt26duru71dLSon379umll17SHXfccel7AQDITWYSJJm9e/em/5xKpUwwGDTf+9730ssGBweN1+s1Tz31lDHGmGPHjhlJ5vDhw+lt9u/fb1wul3n33XfH9XMjkYiRZCKRyGSmDwBwyHgfx7P6ntTx48cVDodVV1eXXub3+7Vy5Uq1t7dLktrb21VWVqbly5ent6mrq5Pb7dbBgwcvOm4sFlM0Gs24AQByX1YjFQ6HJUmBQCBjeSAQSK8Lh8OqrKzMWJ+fn6/y8vL0Nh/V3Nwsv9+fvlVXV2dz2gAAS82Is/uampoUiUTSt76+PqenBACYBlmNVDAYlCT19/dnLO/v70+vCwaDGhgYyFifSCR05syZ9DYf5fV65fP5Mm4AgNyX1UgtWLBAwWBQra2t6WXRaFQH
Dx5UKBSSJIVCIQ0ODqqjoyO9zYEDB5RKpbRy5cpsTgcAMMPlT/QOw8PDevvtt9N/Pn78uDo7O1VeXq6amhpt3bpVf//3f6+FCxdqwYIFuvfee1VVVaUbb7xRkvT5z39e119/vW6//XY99thjisfj2rRpk771rW+pqqoqazsGAMgBEz1t8MUXXzSSPnZbv369Meb909DvvfdeEwgEjNfrNatXrzY9PT0ZY5w+fdrccsstpqSkxPh8PnPbbbeZoaGhrJ+6CACw03gfx13GGONgIy9JNBqV3+9XJBLh/SkAmIHG+zg+I87uAwB8NhEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFoTvgo6AGeYVEqpWEzRN97QUGenRt99V6mxMXnKylS8aJHKrrlG3kBAysuTy+VyerpAVhApYAZIjo5q6I03dHLPHp17+20plcpYP/jqqwr/5Ceq+PKXVfn1r78fKyAHECnAciaR0KkXXlD4pz9V4uzZT9jIKDk0pIFnnlHs5En93re/rVk1NdM7UWAKECnAYiaV0qkXX9TJp55S8ty5cd0n8tprkqTq229XQWUlL/1hRuPECcBiw93dOvnkk+MOlCQplVLk0CGF/+3fpGRy6iYHTAMiBVgqFYvpzMsvK3769CXd//R//qdiAwNZnhUwvYgUYCFjjKJvvKEzbW2XPkYioXd37crepAAH8J4UYKHkyIh+8w//IBOPT2qcRDSapRkBzuCZFGCh+NmzkjFOTwNwHJECLBQ/fVqGSAFECrDRUHf3xz6wC3wWESnAQsNdXbzcB4hIATktr6TE6SkAk0KkgFzldqv8uuucngUwKUQKsEwqHpfJxvtRLpc85eWTHwdwEJECLJMcHlZqbGzS47hcLhXMmZOFGQHOIVKAZRLDw5P+EO8F+aWlWRkHcAqRAiyTGBpSKhbLyliuvLysjAM4hUgBlhk+elRjXBgWkESkAPvwIV4gjUgBFjHGKFsf4Z39xS/ych9mPCIF2MQYmUQiK0MVXX655OavOGY2foMBi6TicSWHh7MyVn55ucRXx2OGI1KARczYWNa+A8rj82VlHMBJRAqwSGJ4WKPvvpudwdxuuXgmhRmOSAEWSUajOn/8uNPTAKxBpIAc5PJ4OLMPOYFIATmodPFieS+7zOlpAJNGpABLGGOydjmkvOJiuQsKsjIW4CQiBVhk7OzZrIyTV1oqt9eblbEAJxEpwBbGKH7qVFaGyps1Sy6PJytjAU4iUoAtjNGZtrbsjOVycfo5cgKRAmxhjGLhsNOzAKxCpIAc4y4slDcQcHoaQFYQKSDH5JWUaNaCBU5PA8gKIgVYIpml08/dHo/yuW4fcgSRAiyROHtWMpP/NimXx6P80tIszAhwHpECLBE/e1YmC9/K68rPV15RURZmBDiPSAGWOPeb32QlUpLk4ssOkSP4TQYsET1yREomnZ4GYBUiBeQYrtmHXEKkgFzicmnu177m9CyArCFSgAVS8bhMll7q85SXZ2UcwAZECrBAcmQka1/TUVBRkZVxABsQKcACieFhpUZHJz+Qy6V8v3/y4wCWIFKABc7/939rbGAgO4Nx9XPkECIFWCB+5oyS5845PQ3AOkQKyCHeyy7jg7zIKfw2Aw4zWbhe3wW+pUvlys/P2niA04gU4LRUSqmxsawM5amo4JkUcgq/zYDDTCKhRDSalbE8s2dz4gRyCpECHJZKJJQYGsrKWHklJUQKOWVCkWpubtaKFStUWlqqyspK3Xjjjerp6cnYZnR0VI2NjaqoqFBJSYnWrl2r/v7+jG16e3vV0NCgoqIiVVZW6u6771YikZj83gAzUGp0VLFwOCtjudxuuYgUcsiEItXW1qbGxka9+uqramlpUTwe15o1azQyMpLe5q677tKzzz6rPXv2qK2tTSdOnNBNN92UXp9MJtXQ0KCxsTG98soreuKJJ7Rr1y7t2LEje3sFzCDJoSGNvPWW
09MArOQykzi16L333lNlZaXa2tr0xS9+UZFIRHPnztWTTz6pP/3TP5UkvfXWW/r85z+v9vZ2rVq1Svv379fXv/51nThxQoFAQJL02GOP6a//+q/13nvvqWAcV3CORqPy+/2KRCLy8TXZmOHOv/OOjv3lX05+IJdLn9u+XWUrVkx+LGCKjfdxfFLvSUUiEUlS+QcXtOzo6FA8HlddXV16m0WLFqmmpkbt7e2SpPb2dl111VXpQElSfX29otGouru7L/pzYrGYotFoxg3IBcYYZesE9JLFizXr8suzNBpgh0uOVCqV0tatW3Xttddq8eLFkqRwOKyCggKVlZVlbBsIBBT+4DX3cDicEagL6y+su5jm5mb5/f70rbq6+lKnDVgnK9fsk5RfXKy8wsKsjAXY4pIj1djYqKNHj2r37t3ZnM9FNTU1KRKJpG99fX1T/jOB6RI/fTor4+QVF8tNpJBjLumj6Zs2bdK+ffv00ksvad68eenlwWBQY2NjGhwczHg21d/fr2AwmN7m0KFDGeNdOPvvwjYf5fV65fV6L2WqgPXGTp3KyjjuwkK5PJ6sjAXYYkLPpIwx2rRpk/bu3asDBw5owYIFGeuXLVsmj8ej1tbW9LKenh719vYqFApJkkKhkLq6ujTwoSs+t7S0yOfzqba2djL7AsxIkY/8o20yOP0cuWZCz6QaGxv15JNP6plnnlFpaWn6PSS/369Zs2bJ7/drw4YN2rZtm8rLy+Xz+bR582aFQiGtWrVKkrRmzRrV1tbq1ltv1YMPPqhwOKzt27ersbGRZ0v47DFG5955x+lZANaaUKR27twpSfrSl76Usfzxxx/Xd77zHUnSQw89JLfbrbVr1yoWi6m+vl6PPvpoetu8vDzt27dPGzduVCgUUnFxsdavX6/7779/cnsCfIa5PB6+kRc5aVKfk3IKn5NCrjCplN649VYlJ3lZpPzZs/W5piaVLFqUpZkBU2taPicFYHKSo6NSFv6d6MrPVz7/YEMOIlKAgxJnz8qkUpMex52fL89HPp8I5AIiBTgofvaslIVIye3mM1LISUQKcFAsHJZJJrMyFl92iFzEbzXgoMjhwzLxuNPTAKxFpIAcwId4kauIFOCQbH76o/KGG7I2FmATIgU4xMTjMln6RuqCj3yzAJAriBTgkOS5c1n7mo6CD77TDcg1RApwSHJkRMnz57MyVn5paVbGAWxDpACHjJ44obH33svOYJw4gRxFpACHxE+dUiIScXoagNWIFDDDecrL+SAvcha/2cAM51+xQm6+iw05ikgBDjCp1PtXQM8Cz+zZUl5eVsYCbEOkAAeYREKJwcGsjOWZPZuX+5Cz+M0GHGCSSSWi0ayM5S4q4uw+5CwiBTggEY3q7C9/mZWxXOLafchdRApwgEmlsna1CSCX5Ts9AQDvu9gFZ3/XM6SCykp55s6dqikBjiNSwDQzxkgfCpIxRmOplE6NjurwqVN6Z3hYbpdL//sP/kBF+Z/+V7Rg7lyu24ecRqQAByTPnZP0fqDOjo2pfWBA0Xhcyyoq9PXqauWP82y9vFmz5J41ayqnCjiKSAEOiJ8+LWOMzieT2tfXp8VlZaqrqpJ3gp93chcVKY9IIYcRKcABY6dOaSyV0jO9vVo5d64W+nxyX8IZeu6CArk8nimYIWAHIgU4YKirS8cGBzWnsFCfKy3NCNS7IyM6cuaMhuJxzS0sVGjuXBV/Sog4/Ry5jEgBDjjd3a3w+fNa5Pcr74PIGGN0fHhY9x05oneGhzWaTMrn8Wjx7Nn6PytWyPOh96my+dXzgM34nBTggHOJhEbicf1eUVH6mdBvhod1+y9/qTcjEZ1PJmUkReJx/XJgQFsOHtTpDz5XZYzRG2fO6Lwxyvf5HNwLYOoRKcABKUlGUsGHTpT4fne3IvH4Rbc/dOqUWk6ckCRF43Hteecd/XciobKVK6dhtoBziBTggJo77hj3aeYfZozRiXPndOLcOR0+e1b5fv8UzA6wB+9J
AdPM5XJp1vz5Slzi+0qXFxfrS8Gg/ri6+v2v6QByGM+kAAe4XC7JGCVTqfSyhupqeT7hTL35JSVaUl4ul8ullKRSj0eVRUVyFxZO04wBZxApwAHFpaWqvPpqvReLpc/Uq6+q0n1XX63CvLz0X8w8l0sVXq/+74oVqi0rS7/c5yso0Pw/+zPndgCYJrzcBzjAV16uKxsa9JtHH1VlYWH66zbqq6o0r6hI+377W50eHdX8khLdvGCBKj74eviEMXpneFgLPvc5VV57LZ+RQs4jUoADXC6XrrnuOh1+/nkN/Pa3Cn5waSOXy6XFs2dr8Se81/SboSGdKynRFxsbOf0cnwm83Ac4ZE4goPXbt+s/8/IUHh1V6hNOpDAfvHf1djSqV0dG9O277lIwFJJrgtf5A2YiIgU4xOVyaf7v/74aH35Yh6uqdHBoSAOjo0p+KFaJVEonzp1T28CAjhYXa/Pf/Z2uuOEGuQsKHJw5MH1cZgZeXyUajcrv9ysSicjHSx6Y4YwxCp84ocMvv6zuZ57R6bfflonFlEql5M7LU/m8efpf11+v0Fe+osqFC3kGhZww3sdxIgVYIpVKKTo4qNHz59NfjOhyuTSrsFClfr/cXO0cOWS8j+OcOAFYwu12q4xv2QUy8J4UAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtSYUqZ07d2rJkiXy+Xzy+XwKhULav39/ev3o6KgaGxtVUVGhkpISrV27Vv39/Rlj9Pb2qqGhQUVFRaqsrNTdd9+tRCKRnb0BAOSUCUVq3rx5euCBB9TR0aHXXntNX/7yl3XDDTeou7tbknTXXXfp2Wef1Z49e9TW1qYTJ07opptuSt8/mUyqoaFBY2NjeuWVV/TEE09o165d2rFjR3b3CgCQG8wkzZ492/zwhz80g4ODxuPxmD179qTXvfnmm0aSaW9vN8YY89xzzxm3223C4XB6m507dxqfz2disdi4f2YkEjGSTCQSmez0AQAOGO/j+CW/J5VMJrV7926NjIwoFAqpo6ND8XhcdXV16W0WLVqkmpoatbe3S5La29t11VVXKRAIpLepr69XNBpNPxu7mFgspmg0mnEDAOS+CUeqq6tLJSUl8nq9uvPOO7V3717V1tYqHA6roKBAZWVlGdsHAgGFw2FJUjgczgjUhfUX1n2S5uZm+f3+9K26unqi0wYAzEATjtSVV16pzs5OHTx4UBs3btT69et17NixqZhbWlNTkyKRSPrW19c3pT8PAGCH/IneoaCgQFdccYUkadmyZTp8+LB+8IMf6Oabb9bY2JgGBwcznk319/crGAxKkoLBoA4dOpQx3oWz/y5sczFer1der3eiUwUAzHCT/pxUKpVSLBbTsmXL5PF41Nraml7X09Oj3t5ehUIhSVIoFFJXV5cGBgbS27S0tMjn86m2tnayUwEA5JgJPZNqamrSV7/6VdXU1GhoaEhPPvmkfv7zn+uFF16Q3+/Xhg0btG3bNpWXl8vn82nz5s0KhUJatWqVJGnNmjWqra3VrbfeqgcffFDhcFjbt29XY2Mjz5QAAB8zoUgNDAzo29/+tk6ePCm/368lS5bohRde0Fe+8hVJ0kMPPSS32621a9cqFoupvr5ejz76aPr+eXl52rdvnzZu3KhQKKTi4mKtX79e999/f3b3CgCQE1zGGOP0JCYqGo3K7/crEonI5/M5PR0AwASN93Gca/cBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJS
AABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsNakIvXAAw/I5XJp69at6WWjo6NqbGxURUWFSkpKtHbtWvX392fcr7e3Vw0NDSoqKlJlZaXuvvtuJRKJyUwFAJCDLjlShw8f1r/8y79oyZIlGcvvuusuPfvss9qzZ4/a2tp04sQJ3XTTTen1yWRSDQ0NGhsb0yuvvKInnnhCu3bt0o4dOy59LwAAuclcgqGhIbNw4ULT0tJirrvuOrNlyxZjjDGDg4PG4/GYPXv2pLd98803jSTT3t5ujDHmueeeM26324TD4fQ2O3fuND6fz8RisXH9/EgkYiSZSCRyKdMHADhsvI/jl/RMqrGxUQ0NDaqrq8tY3tHRoXg8nrF80aJFqqmpUXt7uySpvb1dV111lQKBQHqb+vp6RaNRdXd3X/TnxWIxRaPRjBsAIPflT/QOu3fv1uuvv67Dhw9/bF04HFZBQYHKysoylgcCAYXD4fQ2Hw7UhfUX1l1Mc3Ozvvvd7050qgCAGW5Cz6T6+vq0ZcsW/fjHP1ZhYeFUzeljmpqaFIlE0re+vr5p+9kAAOdMKFIdHR0aGBjQF77wBeXn5ys/P19tbW16+OGHlZ+fr0AgoLGxMQ0ODmbcr7+/X8FgUJIUDAY/drbfhT9f2OajvF6vfD5fxg0AkPsmFKnVq1erq6tLnZ2d6dvy5cu1bt269H97PB61tram79PT06Pe3l6FQiFJUigUUldXlwYGBtLbtLS0yOfzqba2Nku7BQDIBRN6T6q0tFSLFy/OWFZcXKyKior08g0bNmjbtm0qLy+Xz+fT5s2bFQqFtGrVKknSmjVrVFtbq1tvvVUPPvigwuGwtm/frsbGRnm93iztFgAgF0z4xInf5aGHHpLb7dbatWsVi8VUX1+vRx99NL0+Ly9P+/bt08aNGxUKhVRcXKz169fr/vvvz/ZUAAAznMsYY5yexERFo1H5/X5FIhHenwKAGWi8j+Ncuw8AYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYK18pydwKYwxkqRoNOrwTAAAl+LC4/eFx/NPMiMjdfr0aUlSdXW1wzMBAEzG0NCQ/H7/J66fkZEqLy+XJPX29n7qzn3WRaNRVVdXq6+vTz6fz+npWIvjND4cp/HhOI2PMUZDQ0Oqqqr61O1mZKTc7vffSvP7/fwSjIPP5+M4jQPHaXw4TuPDcfrdxvMkgxMnAADWIlIAAGvNyEh5vV7dd9998nq9Tk/Fahyn8eE4jQ/HaXw4TtnlMr/r/D8AABwyI59JAQA+G4gUAMBaRAoAYC0iBQCw1oyM1COPPKL58+ersLBQK1eu1KFDh5ye0rR66aWX9I1vfENVVVVyuVx6+umnM9YbY7Rjxw5ddtllmjVrlurq6vTrX/86Y5szZ85o3bp18vl8
Kisr04YNGzQ8PDyNezG1mpubtWLFCpWWlqqyslI33nijenp6MrYZHR1VY2OjKioqVFJSorVr16q/vz9jm97eXjU0NKioqEiVlZW6++67lUgkpnNXptTOnTu1ZMmS9AdPQ6GQ9u/fn17PMbq4Bx54QC6XS1u3bk0v41hNETPD7N692xQUFJgf/ehHpru729x+++2mrKzM9Pf3Oz21afPcc8+Zv/3bvzU//elPjSSzd+/ejPUPPPCA8fv95umnnzZvvPGG+ZM/+ROzYMECc/78+fQ2119/vVm6dKl59dVXzcsvv2yuuOIKc8stt0zznkyd+vp68/jjj5ujR4+azs5O87Wvfc3U1NSY4eHh9DZ33nmnqa6uNq2trea1114zq1atMn/0R3+UXp9IJMzixYtNXV2dOXLkiHnuuefMnDlzTFNTkxO7NCV+9rOfmX//9383//Vf/2V6enrM3/zN3xiPx2OOHj1qjOEYXcyhQ4fM/PnzzZIlS8yWLVvSyzlWU2PGReqaa64xjY2N6T8nk0lTVVVlmpubHZyVcz4aqVQqZYLBoPne976XXjY4OGi8Xq956qmnjDHGHDt2zEgyhw8fTm+zf/9+43K5zLvvvjttc59OAwMDRpJpa2szxrx/TDwej9mzZ096mzfffNNIMu3t7caY9/8x4Ha7TTgcTm+zc+dO4/P5TCwWm94dmEazZ882P/zhDzlGFzE0NGQWLlxoWlpazHXXXZeOFMdq6syol/vGxsbU0dGhurq69DK32626ujq1t7c7ODN7HD9+XOFwOOMY+f1+rVy5Mn2M2tvbVVZWpuXLl6e3qaurk9vt1sGDB6d9ztMhEolI+v8XJ+7o6FA8Hs84TosWLVJNTU3GcbrqqqsUCATS29TX1ysajaq7u3saZz89ksmkdu/erZGREYVCIY7RRTQ2NqqhoSHjmEj8Pk2lGXWB2VOnTimZTGb8T5akQCCgt956y6FZ2SUcDkvSRY/RhXXhcFiVlZUZ6/Pz81VeXp7eJpekUilt3bpV1157rRYvXizp/WNQUFCgsrKyjG0/epwudhwvrMsVXV1dCoVCGh0dVUlJifbu3ava2lp1dnZyjD5k9+7dev3113X48OGPreP3aerMqEgBl6KxsVFHjx7VL37xC6enYqUrr7xSnZ2dikQi+slPfqL169erra3N6WlZpa+vT1u2bFFLS4sKCwudns5nyox6uW/OnDnKy8v72Bkz/f39CgaDDs3KLheOw6cdo2AwqIGBgYz1iURCZ86cybnjuGnTJu3bt08vvvii5s2bl14eDAY1NjamwcHBjO0/epwudhwvrMsVBQUFuuKKK7Rs2TI1Nzdr6dKl+sEPfsAx+pCOjg4NDAzoC1/4gvLz85Wfn6+2tjY9/PDDys/PVyAQ4FhNkRkVqYKCAi1btkytra3pZalUSq2trQqFQg7OzB4LFixQMBjMOEbRaFQHDx5MH6NQKKTBwUF1dHSktzlw4IBSqZRWrlw57XOeCsYYbdq0SXv37tWBAwe0YMGCjPXLli2Tx+PJOE49PT3q7e3NOE5dXV0ZQW9paZHP51Ntbe307IgDUqmUYrEYx+hDVq9era6uLnV2dqZvy5cv17p169L/zbGaIk6fuTFRu3fvNl6v1+zatcscO3bM3HHHHaasrCzjjJlcNzQ0ZI4cOWKOHDliJJl//Md/NEeOHDH/8z//Y4x5/xT0srIy88wzz5hf/epX5oYbbrjoKehXX321OXjwoPnFL35hFi5cmFOnoG/cuNH4/X7z85//3Jw8eTJ9O3fuXHqbO++809TU1JgDBw6Y1157zYRCIRMKhdLrL5wyvGbNGtPZ2Wmef/55M3fu3Jw6Zfiee+4xbW1t5vjx4+ZXv/qVueeee4zL5TL/8R//YYzhGH2aD5/dZwzHaqrMuEgZY8w//dM/mZqaGlNQUGCuueYa8+qrrzo9pWn14osvGkkfu61fv94Y8/5p6Pfee68JBALG6/Wa1atXm56enowx
Tp8+bW655RZTUlJifD6fue2228zQ0JADezM1LnZ8JJnHH388vc358+fNX/zFX5jZs2eboqIi881vftOcPHkyY5x33nnHfPWrXzWzZs0yc+bMMX/1V39l4vH4NO/N1PnzP/9zc/nll5uCggIzd+5cs3r16nSgjOEYfZqPRopjNTX4qg4AgLVm1HtSAIDPFiIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCs9f8Aj/4ccs/Bx08AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-376.63594003133085"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run test() once with play=True; the saved output shows a rendered frame followed by the value test() returns (presumably the episode score averaged in train's log)\n",
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
